{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append('./bert')\n",
    "import lm_model\n",
    "import tensorflow as tf\n",
    "import tokenization\n",
    "from optimization import create_optimizer\n",
    "import model_transfer as mtransfer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data locations, defined once and reused below.\n",
    "vocab_file = './model/vocab.txt'\n",
    "train_file = './train.tfrecord'\n",
    "\n",
    "tokenizer = tokenization.FullTokenizer(vocab_file)\n",
    "\n",
    "# NOTE: `config` is the very same object as lm_model.BASE_PARAMS, so the\n",
    "# overrides below modify that module-level object in place.\n",
    "config = lm_model.BASE_PARAMS\n",
    "config['vocab_size'] = len(tokenizer.vocab)\n",
    "config['max_length'] = 64\n",
    "config['batch_size'] = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# processor = lm_model.LMWikiDataProcessor()\n",
    "# ds = lm_model.LMDataSet('./model/vocab.txt', config['max_length'])\n",
    "# d = ds.get_ds('./train.tfrecord', config['batch_size'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_iterator = d.make_one_shot_iterator()\n",
    "# train_inputs = train_iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# with tf.Session() as sess:\n",
    "#     train_inputs_eval = sess.run(train_inputs)\n",
    "#     print(train_inputs_eval)\n",
    "#     print(tokenizer.convert_ids_to_tokens(train_inputs_eval['text'][0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_feed_dict(model, inputs):\n",
    "    \"\"\"Build the feed dict for one fetched batch.\n",
    "\n",
    "    Keys are `model.inp` and `model.inp_len` (presumably the model's input\n",
    "    placeholders -- confirm in lm_model); values are the 'text' and\n",
    "    'text_len' entries of `inputs`.\n",
    "    \"\"\"\n",
    "    feed = {}\n",
    "    feed[model.inp] = inputs['text']\n",
    "    feed[model.inp_len] = inputs['text_len']\n",
    "    return feed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = lm_model.LMModel(config, config['max_length'])\n",
    "# loss = model.loss(True)\n",
    "# train_op = create_optimizer(loss, 0.01, 100000, 10000, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loop_num = 0\n",
    "# with tf.Session() as sess:\n",
    "#     sess.run(tf.global_variables_initializer())\n",
    "#     while loop_num < 200:\n",
    "#         loop_num += 1\n",
    "#         inputs = sess.run(train_inputs)\n",
    "#         inputs_dict = make_feed_dict(model, inputs)\n",
    "#         loss_val = sess.run([loss, train_op], inputs_dict)\n",
    "#         print(loss_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(train_file, vocab_file, config, log_dir, pretrained=None, log_every=1):\n",
    "    \"\"\"Train the language model inside a tf.train.Supervisor-managed session.\n",
    "\n",
    "    Args:\n",
    "        train_file: training data file handed to LMDataSet.get_ds.\n",
    "        vocab_file: vocabulary file handed to LMDataSet.\n",
    "        config: hyper-parameter mapping; reads 'max_length', 'batch_size',\n",
    "            'learning_rate', 'train_steps' and 'learning_rate_warmup_steps'.\n",
    "        log_dir: logdir given to the Supervisor.\n",
    "        pretrained: optional checkpoint path. When set, the saver returned by\n",
    "            mtransfer.partial_transfer restores from it before training starts.\n",
    "        log_every: print the loss every `log_every` completed steps\n",
    "            (0 or None disables printing).\n",
    "\n",
    "    A final checkpoint is written to './final_model' once the loop ends,\n",
    "    whether it ran to completion, ran out of input or was interrupted.\n",
    "    Any other exception propagates to the caller instead of being swallowed.\n",
    "    \"\"\"\n",
    "    graph = tf.Graph()\n",
    "    with graph.as_default():\n",
    "        dataset = lm_model.LMDataSet(vocab_file, config['max_length'])\n",
    "        batches = dataset.get_ds(train_file, config['batch_size'])\n",
    "        train_inputs = batches.make_one_shot_iterator().get_next()\n",
    "\n",
    "        model = lm_model.LMModel(config, config['max_length'])\n",
    "        loss = model.loss(True)\n",
    "        train_steps = config['train_steps']\n",
    "        train_op = create_optimizer(\n",
    "            loss,\n",
    "            config['learning_rate'],\n",
    "            train_steps,\n",
    "            config['learning_rate_warmup_steps'],\n",
    "            False)\n",
    "\n",
    "        # Built before the Supervisor so every op exists before the session starts.\n",
    "        partial_saver = mtransfer.partial_transfer(pretrained) if pretrained else None\n",
    "\n",
    "        sv = tf.train.Supervisor(graph=graph, logdir=log_dir)\n",
    "        with sv.managed_session(master='') as sess:\n",
    "            # No tf.global_variables_initializer() call: the Supervisor takes care of\n",
    "            # initialising/restoring variables, and running an initializer after the\n",
    "            # loop would overwrite the freshly trained weights.\n",
    "            if partial_saver:\n",
    "                partial_saver.restore(sess, pretrained)\n",
    "\n",
    "            completed = 0\n",
    "            try:\n",
    "                for step in range(train_steps):\n",
    "                    if sv.should_stop():\n",
    "                        break\n",
    "                    batch = sess.run(train_inputs)\n",
    "                    loss_val, _ = sess.run(\n",
    "                        [loss, train_op], feed_dict=make_feed_dict(model, batch))\n",
    "                    completed = step + 1\n",
    "                    if log_every and completed % log_every == 0:\n",
    "                        print('====> [{}/{}]\\tloss:{:.3f}'.format(completed, train_steps, loss_val))\n",
    "            except tf.errors.OutOfRangeError:\n",
    "                # The one-shot iterator ran out of records before train_steps.\n",
    "                print('====> input exhausted after {} steps'.format(completed))\n",
    "            except KeyboardInterrupt:\n",
    "                print('====> interrupted after {} steps'.format(completed))\n",
    "\n",
    "            sv.saver.save(sess, './final_model', global_step=completed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/higgs/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py:112: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-9-b18c3d44f696>:14: Supervisor.__init__ (from tensorflow.python.training.supervisor) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please switch to tf.train.MonitoredTrainingSession\n",
      "INFO:tensorflow:Restoring parameters from ./log/model.ckpt-0\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Starting standard services.\n",
      "INFO:tensorflow:Saving checkpoint to path ./log/model.ckpt\n",
      "INFO:tensorflow:Starting queue runners.\n",
      "INFO:tensorflow:global_step/sec: 0\n",
      "INFO:tensorflow:global_step/sec: 1.31669\n",
      "INFO:tensorflow:global_step/sec: 1.42501\n",
      "INFO:tensorflow:global_step/sec: 1.48336\n",
      "INFO:tensorflow:global_step/sec: 1.49166\n",
      "INFO:tensorflow:Saving checkpoint to path ./log/model.ckpt\n",
      "INFO:tensorflow:global_step/sec: 1.49995\n",
      "INFO:tensorflow:global_step/sec: 1.51672\n",
      "INFO:tensorflow:global_step/sec: 1.50831\n",
      "INFO:tensorflow:global_step/sec: 1.49169\n",
      "INFO:tensorflow:global_step/sec: 1.47494\n",
      "INFO:tensorflow:Saving checkpoint to path ./log/model.ckpt\n",
      "INFO:tensorflow:Saving checkpoint to path ./log/model.ckpt\n",
      "INFO:tensorflow:Saving checkpoint to path ./log/model.ckpt\n",
      "INFO:tensorflow:Saving checkpoint to path ./log/model.ckpt\n"
     ]
    }
   ],
   "source": [
    "train(train_file, vocab_file, config, log_dir='./log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# graph = tf.Graph()\n",
    "# with graph.as_default():\n",
    "#     ds = lm_model.LMDataSet(vocab_file, config['max_length'])\n",
    "#     d = ds.get_ds(train_file, config['batch_size'])\n",
    "#     train_iterator = d.make_one_shot_iterator()\n",
    "#     train_inputs = train_iterator.get_next()\n",
    "#     model = lm_model.LMModel(config, config['max_length'])\n",
    "#     loss = model.loss(True)\n",
    "#     train_op = create_optimizer(loss, config['learning_rate'], config['train_steps'], config['learning_rate_warmup_steps'], False)\n",
    "#     partialSaver = None\n",
    "#     sv = tf.train.Supervisor(graph=graph, logdir='./log')\n",
    "#     with sv.managed_session(master='') as sess:\n",
    "# #         sess.run(tf.global_variables_initializer())\n",
    "#         inputs = sess.run(train_inputs)\n",
    "#         print(inputs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
